% Compute Bayes Factors
%
% For each constrained model listed in str_mod_vec this script:
%   1. reads Stan parameter draws from two runs ("a" = *_Model_1.Draws.*,
%      "b" = *_Model_2.Draws.*), each providing xi/sigma/mu for parameter
%      sets 1 and 2,
%   2. evaluates the log-likelihood of Y_data under parameter sets 1 and 2
%      for every draw of each run (via get_llf),
%   3. passes the four log-likelihood vectors to get_BF and writes the
%      Bayes factor and log Bayes factor (with standard errors) to a text file.

clear all
small = 1.0e-10;     % not referenced in this script
pinv_tol = 1.0e-05;  % not referenced in this script
big = 1.0e+8;        % not referenced in this script

% include path for utilities/functions
% (presumably provides get_BF and llf_GEV_tv_stan -- confirm)
addpath('../Matlab_Functions');

% -- File Directories
outdir = 'out/';
figdir = 'fig/';     % not referenced in this script
matdir = 'mat/';

% Identifiers for files
% -------------------------- Firm Size --------------
application_str = 'FirmSize.';
% NOTE(review): machine-specific absolute path; adjust when running elsewhere.
application_dir = '/Users/mwatson/Dropbox/TVExtreme/ReplicationFiles/Stan/FirmSize/';
stan_draws_dir = [application_dir 'CSV/'];
model_str = 'GEV_4RW';
data_fname = [matdir 'FirmSize_top100_byYear.mat'];
load(data_fname);    % expected to define emp_largest_normalized and calvec
emp_all = emp_largest_normalized;
yr = calvec;
yr_use = [1950 1973 1996 2019]';
k = 30;              % number of leading columns of emp_all kept per year

% Keep the first k columns of emp_all for each year in yr_use.
emp_use = NaN(length(yr_use),k);
for i = 1:length(yr_use)
    yr_idx = find(yr==yr_use(i));
    % Fail with a readable message instead of a dimension-mismatch error
    % when a requested year is missing or duplicated in calvec.
    if numel(yr_idx) ~= 1
        error('Year %d must appear exactly once in calvec (found %d matches).', ...
              yr_use(i),numel(yr_idx));
    end
    emp_use(i,:) = emp_all(yr_idx,1:k);
end
Y_data = emp_use;
T = size(Y_data,1);
k = size(Y_data,2);
% -------------------------------------------------------


% Models to construct Bayes Factors
str_mod_vec = {'_v_constant' ...
               '_v_alpha' ...
               '_v_xi_sigma_mu' ...
                }' ;

% Save Results to output file
outfname = [outdir application_str model_str '_BayesFactors.txt'];
fid = fopen(outfname,'w');
if fid == -1
    error('Could not open output file %s for writing.',outfname);
end

% Everything that writes to fid is wrapped so the file is closed even when
% reading draws or computing a Bayes factor fails.
try
    fprintf(fid,'Data: %s\n',data_fname);
    fprintf(fid,'T = %3i \n',T);
    fprintf(fid,'k = %3i \n\n',k);

    for i = 1:numel(str_mod_vec)
        str_mod = str_mod_vec{i};
        name_str = [stan_draws_dir application_str model_str str_mod];
        tic;
        % Load Stan Results
        csv_str_a = [name_str '_Model_1.Draws'];
        csv_str_b = [name_str '_Model_2.Draws'];

        xi1_a    = read_stan_draws(csv_str_a,'xi_1');
        sigma1_a = read_stan_draws(csv_str_a,'sigma_1');
        mu1_a    = read_stan_draws(csv_str_a,'mu_1');
        xi2_a    = read_stan_draws(csv_str_a,'xi_2');
        sigma2_a = read_stan_draws(csv_str_a,'sigma_2');
        mu2_a    = read_stan_draws(csv_str_a,'mu_2');

        xi1_b    = read_stan_draws(csv_str_b,'xi_1');
        sigma1_b = read_stan_draws(csv_str_b,'sigma_1');
        mu1_b    = read_stan_draws(csv_str_b,'mu_1');
        xi2_b    = read_stan_draws(csv_str_b,'xi_2');
        sigma2_b = read_stan_draws(csv_str_b,'sigma_2');
        mu2_b    = read_stan_draws(csv_str_b,'mu_2');

        % The loop below indexes the draws of both runs with the same irep,
        % so the two runs must supply the same number of draws.
        nrep = size(xi1_a,1);
        if size(xi1_b,1) ~= nrep
            error('Run %s has %d draws but run %s has %d; draw counts must match.', ...
                  csv_str_a,nrep,csv_str_b,size(xi1_b,1));
        end

        % Construct Log-likelihoods: L_x_y is the log-likelihood for model x using model y parameter draws
        L_A_A = zeros(nrep,1);
        L_B_A = zeros(nrep,1);
        L_A_B = zeros(nrep,1);
        L_B_B = zeros(nrep,1);

        for irep = 1:nrep
            parms_a_a = [xi1_a(irep,:)' sigma1_a(irep,:)' mu1_a(irep,:)'];
            parms_a_b = [xi1_b(irep,:)' sigma1_b(irep,:)' mu1_b(irep,:)'];
            parms_b_a = [xi2_a(irep,:)' sigma2_a(irep,:)' mu2_a(irep,:)'];
            parms_b_b = [xi2_b(irep,:)' sigma2_b(irep,:)' mu2_b(irep,:)'];
            L_A_A(irep) = get_llf(Y_data,parms_a_a);
            L_B_A(irep) = get_llf(Y_data,parms_b_a);
            L_A_B(irep) = get_llf(Y_data,parms_a_b);
            L_B_B(irep) = get_llf(Y_data,parms_b_b);
        end

        % Compute Bayes Factor
        [BF,log_BF,se_BF,se_log_BF] = get_BF(L_A_A,L_B_A,L_A_B,L_B_B);
        % Pass the model name as data, not as part of the format string.
        fprintf(fid,'Constrained Model: %s\n',str_mod);
        fprintf(fid,'    Bayes Factors (se):  %8.4f (%8.4f) \n',[BF se_BF]);
        fprintf(fid, '    Log(BF) (SE): %8.4f (%8.4f) \n\n\n',[log_BF se_log_BF]);

        toc;
    end
catch err
    fclose(fid);
    rethrow(err);
end

fclose(fid);

function draws = read_stan_draws(csv_prefix,parm_name)
% READ_STAN_DRAWS  Read the draws of one parameter from a Stan draws CSV.
%   draws = read_stan_draws(csv_prefix,parm_name) reads the file
%   [csv_prefix '.' parm_name '.csv'] with readmatrix and returns it with
%   the first row and first column removed.
%   NOTE(review): presumably row 1 is a header/index row and column 1 a draw
%   index -- confirm against the CSV layout, since readmatrix may already
%   skip a text header line on its own.
    tmp = readmatrix([csv_prefix '.' parm_name '.csv']);
    draws = tmp(2:end,2:end);
end

function llf = get_llf(Y,parms)
% GET_LLF  Summed log-likelihood of Y for a single parameter draw.
%   llf = get_llf(Y,parms) hands the three columns of PARMS to
%   llf_GEV_tv_stan as its xi, sigma and mu arguments (columns 1, 2 and 3,
%   in that order) and returns sum() of the values it produces.
%   NOTE(review): the columns are presumably the GEV shape, scale and
%   location parameters -- confirm against llf_GEV_tv_stan.
    llf = sum(llf_GEV_tv_stan(Y, parms(:,1), parms(:,2), parms(:,3)));
end


